Add bool mask support to scaled_dot_product_attention #72927
Conversation
modified: python/paddle/nn/functional/flash_attention.py modified: test/legacy_test/test_flash_attention.py
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
modified: test/legacy_test/test_flash_attention.py
modified: test/legacy_test/test_flash_attention.py new file: test/legacy_test/test_scaled_dot_product_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
modified: python/paddle/nn/functional/flash_attention.py modified: test/legacy_test/test_scaled_dot_product_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
@@ -1272,6 +1284,7 @@ def scaled_dot_product_attention(
    sdp_func_name = _select_sdp_for_sdpa(
        query, key, attn_mask, dropout_p, is_causal
    )
    attn_mask = _convert_bool_mask_to_float(attn_mask, query.dtype)
The logic of this function is fairly simple; could it just be written inline here instead?
OK, got it. Fixed.
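For context, a minimal sketch of the conversion being discussed (the helper name comes from the diff above; the inlined version in the final PR may differ): True positions are kept at 0.0 and False positions are pushed to the dtype's minimum so softmax effectively zeroes them out.

```python
import paddle

def _convert_bool_mask_to_float(attn_mask, dtype):
    # Sketch only: turn a bool mask (True = attend) into an additive float
    # mask (0.0 for kept positions, a large negative value for masked ones).
    if attn_mask is not None and attn_mask.dtype == paddle.bool:
        zeros = paddle.zeros_like(attn_mask, dtype=dtype)
        neg_inf = paddle.full_like(zeros, paddle.finfo(dtype).min)
        attn_mask = paddle.where(attn_mask, zeros, neg_inf)
    return attn_mask
```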
out_ = attention_naive_with_mask(q_, k_, v_, m)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
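`attention_naive_with_mask` itself is not shown in this excerpt; a plausible sketch of such a reference (assuming a [batch, seq_len, num_heads, head_dim] layout and an additive float mask, which is how a bool mask would be applied after conversion) is:

```python
import paddle
import paddle.nn.functional as F

def attention_naive_with_mask(q, k, v, attn_mask):
    # Naive reference: softmax(Q K^T * scale + mask) V, computed explicitly.
    # Assumes q/k/v are [batch, seq_len, num_heads, head_dim].
    qt = paddle.transpose(q, [0, 2, 1, 3])  # -> [batch, heads, seq, dim]
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / (q.shape[-1] ** 0.5)
    scores = paddle.matmul(qt, kt, transpose_y=True) * scale
    scores = scores + attn_mask.astype(scores.dtype)
    weights = F.softmax(scores, axis=-1)
    out = paddle.matmul(weights, vt)
    return paddle.transpose(out, [0, 2, 1, 3])  # back to [batch, seq, heads, dim]
```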
Please add a test case with a bool attn_mask to the sdpa unit tests in PaConvert locally and check whether the results match PyTorch; attach the PaConvert test results here.
Also remember to update the API mapping docs.
OK, done. I was previously using a Python 3.8 virtual environment; upgrading to Python 3.9 and reinstalling the environment took some time.
PaddlePaddle/PaConvert#586
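For reference, the kind of parity check requested above could look roughly like the following; the layout transpose, dtype, shapes, and tolerances are all illustrative assumptions, not the actual PaConvert test (which lives in the linked PR).

```python
import numpy as np
import paddle
import torch

# Hypothetical parity check, assuming CUDA builds of both frameworks.
q = np.random.rand(1, 8, 2, 16).astype("float16")  # [batch, seq, heads, dim] (Paddle layout)
mask = np.random.rand(1, 2, 8, 8) > 0.3             # bool, [batch, heads, seq_q, seq_k]

p = paddle.to_tensor(q)
paddle_out = paddle.nn.functional.scaled_dot_product_attention(
    p, p, p, attn_mask=paddle.to_tensor(mask)
)

t = torch.tensor(q, device="cuda").permute(0, 2, 1, 3)  # -> [batch, heads, seq, dim] (PyTorch layout)
torch_out = torch.nn.functional.scaled_dot_product_attention(
    t, t, t, attn_mask=torch.tensor(mask, device="cuda")
).permute(0, 2, 1, 3)

np.testing.assert_allclose(
    paddle_out.astype("float32").numpy(),
    torch_out.float().cpu().numpy(),
    rtol=5e-3, atol=1e-3,
)
```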
modified: python/paddle/nn/functional/flash_attention.py
modified: python/paddle/nn/functional/flash_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
modified: test/legacy_test/test_scaled_dot_product_attention.py
Built in the docker image paddlepaddle/paddle:latest-dev-cuda11.8-cudnn8.6-trt8.5-gcc82 with the following command: cmake .. -DPY_VERSION=3.9 -DWITH_GPU=ON -DWITH_TENSORRT=ON -DWITH_TESTING=ON. The test_quant_linear_fuse_pass test passes locally: python ../test/ir/inference/test_quant_linear_fuse_pass.py
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
RuntimeError: module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000
/paddle/test/ir/inference/auto_scan_test.py:61: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:70: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:461: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
suppress_health_check=hypothesis.HealthCheck.all(),
Sun Jun 01 09:14:22-INFO: Start to running test of <class '__main__.TestQuantLinearFusePass'>
I0601 09:14:21.751210 29451 program_interpreter.cc:257] New Executor is Running.
Sun Jun 01 09:14:23-INFO: Number of Invalid Programs: 0
Sun Jun 01 09:14:23-INFO: Number of Ran Programs: 30
Sun Jun 01 09:14:23-INFO: Number of Ignore Tests: 0
.
----------------------------------------------------------------------
Ran 1 test in 1.812s
OK
Please take a look, CI didn't pass.
OK. Could you help take a look at the test_quant_linear_fuse_pass failure? The modified function should not affect that test. There was also another CI job that failed a few runs earlier, but it passed after re-running today.
LGTM
PR Category
User Experience
PR Types
Improvements
Description
Add bool mask support to scaled_dot_product_attention.
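A minimal usage sketch of what this change enables (shapes and values are illustrative, and a GPU build is assumed; before this PR, attn_mask had to be a float mask added to the attention scores). The True-means-attend convention follows PyTorch, which the PaConvert test above compares against.

```python
import paddle
import paddle.nn.functional as F

# Illustrative shapes: [batch, seq_len, num_heads, head_dim], float16 inputs.
q = paddle.rand([2, 128, 8, 64]).astype("float16")
k = paddle.rand([2, 128, 8, 64]).astype("float16")
v = paddle.rand([2, 128, 8, 64]).astype("float16")

# Bool mask broadcastable to [batch, num_heads, seq_q, seq_k]:
# True = attend, False = masked (converted to a large negative bias internally).
causal_mask = paddle.tril(paddle.ones([2, 1, 128, 128])).astype("bool")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(out.shape)  # [2, 128, 8, 64]
```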